Amazon SageMakerでのTensorFlowの学習ログをTensorBoardで確認する

Amazon SageMakerでのTensorFlowの学習ログをTensorBoardで確認する

Clock Icon2020.01.29

この記事は公開されてから1年以上経過しています。情報が古い可能性がありますので、ご注意ください。

SageMakerはTensorFlowの学習に対応しているため、他の環境用に作成した学習用スクリプトを最小限の修正もしくはそのままで使用できます。 今回はSageMakerでのTensorFlowによるモデルの学習情報をTensorBoardで確認する方法を紹介します。

概要

SageMakerのスクリプトモードでTensorFlowのスクリプトを学習させる際には、model_dir引数を受け取れるようにする必要があります。 TensorBoardで見たいログを書き出す場所として、スクリプトへの引数として受け取ったmodel_dirを設定することで、学習中や学習後にTensorBoardから参照することができます。model_dirはS3スキームのURIです。TensorFlowやTensorBoardがS3とのファイルのやりとりをサポートしているため、チェックポイントやロスなどのデータなどをS3上の指定場所へ自動的に書き出したり、TensorBoardがその場所のデータを読み込むことができます。

例えば、引数として受け取ったmodel_dirtf.keras.callbacks.TensorBoardlog_dirtf.summary.create_file_writerlogdirtf.compat.v1.estimator.Estimatormodel_dir等に設定することで、指定したデータが自動的にS3に書き出されます。

TensorBoardで実行する際には次のように認証情報とリージョン情報を環境変数として設定した上でTensorBoardを実行することで、S3上にある学習ログを参照してくれます。

export AWS_ACCESS_KEY_ID=HOGEHOGEHOGE
export AWS_SECRET_ACCESS_KEY=aaaaaaaaaaaa/aaaaaa
# export AWS_SESSION_TOKEN=AAAAAAAAAA.....  # 必要に応じてAWS_SESSION_TOKENも設定
export AWS_REGION=ap-northeast-1 # 対象のS3バケットがあるリージョン
tensorboard --logdir s3://bucket/path/to/log

やってみる

Amazon SageMaker Examplesで紹介されているノートブックの内容を元に試してみます。

学習

学習データはすでにS3にアップされたものを使用するため、今回のモデルの学習に必要なのはTensorFlow用のEstimatorに学習用スクリプトであるエントリポイントやその他パラメータを設定するだけです。

SageMaker SDKのTensorFlowEstimatorでmodel_dirを引数として、どこに学習ログを出力するかを設定することができます。未設定の場合はデフォルトのS3 URIとして学習ジョブ名が接頭辞のパスが指定されます。今回はmodel_dirを未設定で試します。

import os
import sagemaker
from sagemaker import get_execution_role
from sagemaker.tensorflow import TensorFlow
sagemaker_session = sagemaker.Session()
role = get_execution_role()
region = sagemaker_session.boto_session.region_name

# パブリックアクセス可能なS3にある前処理済みのデータを学習に使用する
training_data_uri = 's3://sagemaker-sample-data-{}/tensorflow/mnist'.format(region) 

mnist_estimator = TensorFlow(entry_point='mnist.py',
                             role=role,
                             train_instance_count=2,
                             train_instance_type='ml.p2.xlarge',
                             framework_version='1.14',
                             py_version='py3',
                             distributions={'parameter_server': {'enabled': True}})
mnist_estimator.fit(training_data_uri)

fitを叩くことでSageMaker上で学習が実行されて、ログが表示されます。その中には次のように環境変数情報も含まれています。

学習が完了したら、TensorBoardから参照するためにmodel_dirを確認しておきます。

mnist_estimator.model_dir

学習用スクリプト

エントリポイントとして設定した学習用スクリプトを一部抜粋して確認します。

まずは引数や環境変数をパースしている箇所です。引数としてmodel_dirが、環境変数にはSM_という接頭辞がついた幾つかの変数が設定されています。その中で必要なものをパースして利用します。

設定される環境変数の一覧については以下のドキュメントをご覧ください。

def _parse_args():

    parser = argparse.ArgumentParser()

    # Data, model, and output directories
    # model_dir is always passed in from SageMaker. By default this is a S3 path under the default bucket.
    parser.add_argument('--model_dir', type=str)
    parser.add_argument('--sm-model-dir', type=str, default=os.environ.get('SM_MODEL_DIR'))
    parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAINING'))
    parser.add_argument('--hosts', type=list, default=json.loads(os.environ.get('SM_HOSTS')))
    parser.add_argument('--current-host', type=str, default=os.environ.get('SM_CURRENT_HOST'))

    return parser.parse_known_args()

Estimatorにパースしたmodel_dirを設定します。

mnist_classifier = tf.estimator.Estimator(
    model_fn=cnn_model_fn, model_dir=args.model_dir)

TensorBoardで確認

認証情報とリージョンを環境変数として設定し、TensorBoardを起動させます。

export AWS_ACCESS_KEY_ID=HOGEHOGEHOGE
export AWS_SECRET_ACCESS_KEY=aaaaaaaaaaaa/aaaaaa
# export AWS_SESSION_TOKEN=AAAAAAAAAA.....  # 必要に応じてAWS_SESSION_TOKENも設定
export AWS_REGION=ap-northeast-1 # 対象のS3バケットがあるリージョン
tensorboard --logdir s3://sagemaker-ap-northeast-1-111111111/tensorflow-training-2020-01-29-10-31-40-321/model

ブラウザからhttp://localhost:6006/にアクセスすると、次のように学習で保存されたデータが表示されます。

モデルも確認出来ます。

さいごに

SageMaker上でスクリプトモードで学習させたTensorFlowの学習情報をTensorBoardで確認する方法について紹介しました。TensorBoardを使うことで、SageMakerに依存しない他の環境と共通の方法でモデルや学習状況を把握することが可能です。今回は紹介できませんでしたが、SageMaker DebuggerもTensorBoardと連携することができます。また試し次第、ご紹介したいと思います。

参考

Share this article

facebook logohatena logotwitter logo

© Classmethod, Inc. All rights reserved.